#' To run this example:
#' 
#'  1) Download the train.zip and train_masks.zip files from: 
#'     https://www.kaggle.com/c/carvana-image-masking-challenge/data
#' 
#'  2) Create an "input" directory and extract the zip files into it (after this there
#'     should be "train" and "train_masks" subdirectories within the "input" directory).
#'

#` This code runs only on Windows because of specific parallel backend. 
#` You can find Linux version here: https://keras.rstudio.com/articles/examples/unet_linux.html
#` unet architecture is based on original Python code 
#` from https://github.com/petrosgk/Kaggle-Carvana-Image-Masking-Challenge. 
#` It shows an example of creating custom architectures in R version of keras 
#` and working with images using magick package. 
#` parallel + doParallel + foreach allows to speed up the code.
#` You can download the data from https://www.kaggle.com/c/carvana-image-masking-challenge

library(keras)
library(magick)
library(abind)
library(reticulate)
library(parallel)
library(doParallel)
library(foreach)


# Parameters -----------------------------------------------------

input_size <- 128

epochs <- 30
batch_size <- 16

orig_width <- 1918
orig_height <- 1280

threshold <- 0.5

train_samples <- 5088
train_index <- sample(1:train_samples, round(train_samples * 0.8)) # 80%
val_index <- c(1:train_samples)[-train_index]

images_dir <- "input/train/" 
masks_dir <- "input/train_masks/"


# Loss function -----------------------------------------------------

K <- backend()

dice_coef <- function(y_true, y_pred, smooth = 1.0) {
    y_true_f <- k_flatten(y_true)
    y_pred_f <- k_flatten(y_pred)
    intersection <- k_sum(y_true_f * y_pred_f)
    result <- (2 * intersection + smooth) / 
        (k_sum(y_true_f) + k_sum(y_pred_f) + smooth)
    return(result)
}

bce_dice_loss <- function(y_true, y_pred) {
    result <- loss_binary_crossentropy(y_true, y_pred) +
        (1 - dice_coef(y_true, y_pred))
    return(result)
}


# U-net 128 -----------------------------------------------------

get_unet_128 <- function(input_shape = c(128, 128, 3),
                         num_classes = 1) {
    
    inputs <- layer_input(shape = input_shape)
    # 128
    
    down1 <- inputs %>%
        layer_conv_2d(filters = 64, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 64, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") 
    down1_pool <- down1 %>%
        layer_max_pooling_2d(pool_size = c(2, 2), strides = c(2, 2))
    # 64
    
    down2 <- down1_pool %>%
        layer_conv_2d(filters = 128, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 128, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") 
    down2_pool <- down2 %>%
        layer_max_pooling_2d(pool_size = c(2, 2), strides = c(2, 2))
    # 32
    
    down3 <- down2_pool %>%
        layer_conv_2d(filters = 256, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 256, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") 
    down3_pool <- down3 %>%
        layer_max_pooling_2d(pool_size = c(2, 2), strides = c(2, 2))
    # 16
    
    down4 <- down3_pool %>%
        layer_conv_2d(filters = 512, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 512, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") 
    down4_pool <- down4 %>%
        layer_max_pooling_2d(pool_size = c(2, 2), strides = c(2, 2))
    # 8
    
    center <- down4_pool %>%
        layer_conv_2d(filters = 1024, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 1024, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") 
    # center
    
    up4 <- center %>%
        layer_upsampling_2d(size = c(2, 2)) %>%
        {layer_concatenate(inputs = list(down4, .), axis = 3)} %>%
        layer_conv_2d(filters = 512, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 512, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 512, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu")
    # 16
    
    up3 <- up4 %>%
        layer_upsampling_2d(size = c(2, 2)) %>%
        {layer_concatenate(inputs = list(down3, .), axis = 3)} %>%
        layer_conv_2d(filters = 256, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 256, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 256, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu")
    # 32
    
    up2 <- up3 %>%
        layer_upsampling_2d(size = c(2, 2)) %>%
        {layer_concatenate(inputs = list(down2, .), axis = 3)} %>%
        layer_conv_2d(filters = 128, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 128, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 128, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu")
    # 64
    
    up1 <- up2 %>%
        layer_upsampling_2d(size = c(2, 2)) %>%
        {layer_concatenate(inputs = list(down1, .), axis = 3)} %>%
        layer_conv_2d(filters = 64, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 64, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu") %>%
        layer_conv_2d(filters = 64, kernel_size = c(3, 3), padding = "same") %>%
        layer_batch_normalization() %>%
        layer_activation("relu")
    # 128
    
    classify <- layer_conv_2d(up1,
                              filters = num_classes, 
                              kernel_size = c(1, 1),
                              activation = "sigmoid")
    
    
    model <- keras_model(
        inputs = inputs,
        outputs = classify
    )
    
    model %>% compile(
        optimizer = optimizer_rmsprop(lr = 0.0001),
        loss = bce_dice_loss,
        metrics = c(dice_coef)
    )
    
    return(model)
}

model <- get_unet_128()


# Read and augmentation functions -----------------------------------------------------

imagesRead <- function(image_file,
                       mask_file,
                       target_width = 128, 
                       target_height = 128) {
    img <- image_read(image_file)
    img <- image_scale(img, paste0(target_width, "x", target_height, "!"))
    
    mask <- image_read(mask_file)
    mask <- image_scale(mask, paste0(target_width, "x", target_height, "!"))
    return(list(img = img, mask = mask))
}

randomBSH <- function(img,
                      u = 0,
                      brightness_shift_lim = c(90, 110), # percentage
                      saturation_shift_lim = c(95, 105), # of current value
                      hue_shift_lim = c(80, 120)) {
    
    if (rnorm(1) < u) return(img)
    
    brightness_shift <- runif(1, 
                              brightness_shift_lim[1], 
                              brightness_shift_lim[2])
    saturation_shift <- runif(1, 
                              saturation_shift_lim[1], 
                              saturation_shift_lim[2])
    hue_shift <- runif(1, 
                       hue_shift_lim[1], 
                       hue_shift_lim[2])
    
    img <- image_modulate(img, 
                          brightness = brightness_shift, 
                          saturation =  saturation_shift, 
                          hue = hue_shift)
    img
}

img2arr <- function(image, 
                    target_width = 128,
                    target_height = 128) {
    result <- aperm(as.numeric(image[[1]])[, , 1:3], c(2, 1, 3)) # transpose
    array_reshape(result, c(1, target_width, target_height, 3))
}

mask2arr <- function(mask,
                     target_width = 128,
                     target_height = 128) {
    result <- t(as.numeric(mask[[1]])[, , 1]) # transpose
    array_reshape(result, c(1, target_width, target_height, 1))
}


# Iterators with parallel processing -----------------------------------------------------

cl <- makePSOCKcluster(4) 
clusterEvalQ(cl, {
    library(magick)     
    library(abind)     
    library(reticulate)
    
    imagesRead <- function(image_file,
                           mask_file,
                           target_width = 128, 
                           target_height = 128) {
        img <- image_read(image_file)
        img <- image_scale(img, paste0(target_width, "x", target_height, "!"))
        
        mask <- image_read(mask_file)
        mask <- image_scale(mask, paste0(target_width, "x", target_height, "!"))
        return(list(img = img, mask = mask))
    }
    
    randomBSH <- function(img,
                          u = 0,
                          brightness_shift_lim = c(90, 110), # percentage
                          saturation_shift_lim = c(95, 105), # of current value
                          hue_shift_lim = c(80, 120)) {
        
        if (rnorm(1) < u) return(img)
        
        brightness_shift <- runif(1, 
                                  brightness_shift_lim[1], 
                                  brightness_shift_lim[2])
        saturation_shift <- runif(1, 
                                  saturation_shift_lim[1], 
                                  saturation_shift_lim[2])
        hue_shift <- runif(1, 
                           hue_shift_lim[1], 
                           hue_shift_lim[2])
        
        img <- image_modulate(img, 
                              brightness = brightness_shift, 
                              saturation =  saturation_shift, 
                              hue = hue_shift)
        img
    }
    
    img2arr <- function(image, 
                        target_width = 128,
                        target_height = 128) {
        result <- aperm(as.numeric(image[[1]])[, , 1:3], c(2, 1, 3)) # transpose
        array_reshape(result, <- c(1, target_width, target_height, 3))
    }
    
    mask2arr <- function(mask,
                         target_width = 128,
                         target_height = 128) {
        result <- t(as.numeric(mask[[1]])[, , 1]) # transpose
        array_reshape(result, <- c(1, target_width, target_height, 1))
    }
})

registerDoParallel(cl)


train_generator <- function(images_dir, 
                            samples_index,
                            masks_dir, 
                            batch_size) {
    images_iter <- list.files(images_dir, 
                              pattern = ".jpg", 
                              full.names = TRUE)[samples_index] # for current epoch
    images_all <- list.files(images_dir, 
                             pattern = ".jpg",
                             full.names = TRUE)[samples_index]  # for next epoch
    masks_iter <- list.files(masks_dir, 
                             pattern = ".gif",
                             full.names = TRUE)[samples_index] # for current epoch
    masks_all <- list.files(masks_dir, 
                            pattern = ".gif",
                            full.names = TRUE)[samples_index] # for next epoch
    
    function() {
        
        # start new epoch
        if (length(images_iter) < batch_size) {
            images_iter <<- images_all
            masks_iter <<- masks_all
        }
        
        batch_ind <- sample(1:length(images_iter), batch_size)
        
        batch_images_list <- images_iter[batch_ind]
        images_iter <<- images_iter[-batch_ind]
        batch_masks_list <- masks_iter[batch_ind]
        masks_iter <<- masks_iter[-batch_ind]
        
        
        x_y_batch <- foreach(i = 1:batch_size) %dopar% {
            x_y_imgs <- imagesRead(image_file = batch_images_list[i],
                                   mask_file = batch_masks_list[i])
            # augmentation
            x_y_imgs$img <- randomBSH(x_y_imgs$img)
            # return as arrays
            x_y_arr <- list(x = img2arr(x_y_imgs$img),
                            y = mask2arr(x_y_imgs$mask))
        }
        
        x_y_batch <- purrr::transpose(x_y_batch)
        
        x_batch <- do.call(abind, c(x_y_batch$x, list(along = 1)))
        
        y_batch <- do.call(abind, c(x_y_batch$y, list(along = 1)))
        
        result <- list(keras_array(x_batch), 
                       keras_array(y_batch))
        return(result)
    }
}

val_generator <- function(images_dir, 
                          samples_index,
                          masks_dir, 
                          batch_size) {
    images_iter <- list.files(images_dir, 
                              pattern = ".jpg", 
                              full.names = TRUE)[samples_index] # for current epoch
    images_all <- list.files(images_dir, 
                             pattern = ".jpg",
                             full.names = TRUE)[samples_index]  # for next epoch
    masks_iter <- list.files(masks_dir, 
                             pattern = ".gif",
                             full.names = TRUE)[samples_index] # for current epoch
    masks_all <- list.files(masks_dir, 
                            pattern = ".gif",
                            full.names = TRUE)[samples_index] # for next epoch
    
    function() {
        
        # start new epoch
        if (length(images_iter) < batch_size) {
            images_iter <<- images_all
            masks_iter <<- masks_all
        }
        
        batch_ind <- sample(1:length(images_iter), batch_size)
        
        batch_images_list <- images_iter[batch_ind]
        images_iter <<- images_iter[-batch_ind]
        batch_masks_list <- masks_iter[batch_ind]
        masks_iter <<- masks_iter[-batch_ind]
        
        
        x_y_batch <- foreach(i = 1:batch_size) %dopar% {
            x_y_imgs <- imagesRead(image_file = batch_images_list[i],
                                   mask_file = batch_masks_list[i])
            # without augmentation
            
            # return as arrays
            x_y_arr <- list(x = img2arr(x_y_imgs$img),
                            y = mask2arr(x_y_imgs$mask))
        }
        
        x_y_batch <- purrr::transpose(x_y_batch)
        
        x_batch <- do.call(abind, c(x_y_batch$x, list(along = 1)))
        
        y_batch <- do.call(abind, c(x_y_batch$y, list(along = 1)))
        
        result <- list(keras_array(x_batch), 
                       keras_array(y_batch))
        return(result)
    }
}

train_iterator <- py_iterator(train_generator(images_dir = images_dir,
                                              masks_dir = masks_dir,
                                              samples_index = train_index,
                                              batch_size = batch_size))

val_iterator <- py_iterator(val_generator(images_dir = images_dir,
                                          masks_dir = masks_dir,
                                          samples_index = val_index,
                                          batch_size = batch_size))


# Training -----------------------------------------------------

tensorboard("logs_r")

callbacks_list <- list(
    callback_tensorboard("logs_r"),
    callback_early_stopping(monitor = "val_python_function",
                            min_delta = 1e-4,
                            patience = 8,
                            verbose = 1,
                            mode = "max"),
    callback_reduce_lr_on_plateau(monitor = "val_python_function",
                                  factor = 0.1,
                                  patience = 4,
                                  verbose = 1,
                                  epsilon = 1e-4,
                                  mode = "max"),
    callback_model_checkpoint(filepath = "weights_r/unet128_{epoch:02d}.h5",
                              monitor = "val_python_function",
                              save_best_only = TRUE,
                              save_weights_only = TRUE, 
                              mode = "max" )
)

model %>% fit_generator(
    train_iterator,
    steps_per_epoch = as.integer(length(train_index) / batch_size), 
    epochs = epochs, 
    validation_data = val_iterator,
    validation_steps = as.integer(length(val_index) / batch_size),
    verbose = 1,
    callbacks = callbacks_list
)

stopCluster(cl)
